from .dfaopt import State, Symbol, DFA, DFAOpt
from dpsim.network import FatTreeProvider

import confuse

def find_ft_dfa(src, dst, g, net):
    """
    Build the DFA of up/down paths from ToR switch ``src`` to ToR switch ``dst``.

    Intra-pod (same pod): ToR -> any aggregation switch of the pod -> dst ToR.
    Inter-pod: ToR -> any aggregation switch ``j`` of the source pod -> any core
    attached to ``j`` -> aggregation switch ``j`` of the destination pod -> dst ToR.

    :param src: ``(pod, tor)`` index pair of the source ToR switch.
    :param dst: ``(pod, tor)`` index pair of the destination ToR switch.
        ``gen_ft_dfa`` never passes ``src == dst``.
    :param g: graph produced by ``FatTreeProvider.generate()``; not used here,
        kept for interface compatibility.
    :param net: fat-tree provider; this function reads ``net.K``, ``net.tors``,
        ``net.aggrs``, ``net.cores`` and ``net.n_cores``.
    :return: a ``DFA`` whose only initial state is the source ToR and whose only
        accepting state is the destination ToR.

    Symbol convention: ``Symbol(i)`` with ``i >= 0`` selects the i-th upward
    next hop; downward hops use a negated destination index
    (``Symbol(-dst_pod)`` at the core layer, ``Symbol(-dst_tor)`` at the last
    aggregation layer).
    NOTE(review): for index 0, ``-0 == 0``, so the downward symbol equals the
    upward ``Symbol(0)``. The two never label transitions of the same state
    here, but confirm that DFAOpt does not rely on the sign to tell
    directions apart.
    """
    src_pod, src_tor = src
    dst_pod, dst_tor = dst

    # Aggregation switches per pod, which is also the number of uplinks per
    # switch. int() first keeps the result identical to int(K / 2) for
    # non-negative K.
    aggr_pod = int(net.K) // 2

    # States shared by both path shapes: the source ToR (hop 1) and every
    # aggregation switch of the source pod (hop 2). The first State argument
    # appears to be the hop position along the path.
    tor_1 = State(1, net.tors[src_pod][src_tor])
    aggr_2 = [State(2, net.aggrs[src_pod][i]) for i in range(aggr_pod)]

    if src_pod == dst_pod:
        # Intra-pod: ToR -> aggregation -> ToR.
        tor_3 = State(3, net.tors[dst_pod][dst_tor])

        transitions = {tor_1: {Symbol(i): aggr_2[i] for i in range(aggr_pod)}}
        for a2 in aggr_2:
            transitions[a2] = {Symbol(-dst_tor): tor_3}

        return DFA(
            states = [ tor_1 ] + aggr_2 + [ tor_3 ],
            transitions = transitions,
            initials = [ tor_1 ],
            acceptings = [ tor_3 ]
        )

    # Inter-pod: ToR -> aggregation -> core -> aggregation -> ToR.
    core_3 = [State(3, net.cores[i]) for i in range(net.n_cores)]
    aggr_4 = [State(4, net.aggrs[dst_pod][i]) for i in range(aggr_pod)]
    tor_5 = State(5, net.tors[dst_pod][dst_tor])

    transitions = {tor_1: {Symbol(i): aggr_2[i] for i in range(aggr_pod)}}
    for j, (a2, a4) in enumerate(zip(aggr_2, aggr_4)):
        for k in range(aggr_pod):
            # Aggregation switch j reaches, via uplink k, core k * aggr_pod + j.
            # That core comes back down to aggregation switch j of the
            # destination pod.
            c3 = core_3[k * aggr_pod + j]
            transitions.setdefault(a2, {})[Symbol(k)] = c3
            transitions.setdefault(c3, {})[Symbol(-dst_pod)] = a4
    for a4 in aggr_4:
        transitions[a4] = {Symbol(-dst_tor): tor_5}

    return DFA(
        states = [tor_1] + aggr_2 + core_3 + aggr_4 + [tor_5],
        transitions = transitions,
        initials = [ tor_1 ],
        acceptings = [ tor_5 ]
    )

def gen_ft_dfa(k):
    """
    Generate one path DFA per ordered pair of distinct ToR switches of a fat tree.

    A ``FatTreeProvider`` is built from a ``confuse`` configuration named
    "dpsim" with ``k`` set to the given value. Its graph is generated, and
    ``find_ft_dfa`` is called for every ``(src, dst)`` ToR pair with
    ``src != dst``.

    :param k: fat-tree parameter (int). ToRs are enumerated as ``k`` pods with
        ``k // 2`` ToR switches each.
    :return: list of DFAs, ordered by source ToR, then destination ToR.
    """
    config = confuse.Configuration("dpsim", __name__)
    config['k'] = k

    network = FatTreeProvider(config)
    g = network.generate()

    # Every ToR is identified by its (pod, index-within-pod) pair.
    tors = [(pod, tor) for pod in range(k) for tor in range(k // 2)]

    return [
        find_ft_dfa(src, dst, g, network)
        for src in tors
        for dst in tors
        if src != dst
    ]

if __name__ == '__main__':
    # Command-line entry point: build every ToR-pair DFA for a fat tree of
    # parameter -k, run DFAOpt on them, and print the before/after counts of
    # states and transitions together with the after/before ratio.
    import argparse

    parser = argparse.ArgumentParser(description='DFA compression for Fat tree')
    parser.add_argument('-k', dest='k',
                        type=int, default=4,
                        help='Parameter to the Fat-tree')

    args = parser.parse_args()
    # With k < 2 there are k // 2 == 0 ToRs per pod, so no ToR pair and no
    # DFA is produced, and the ratios below would divide by zero.
    if args.k < 2:
        parser.error('-k must be at least 2')

    dfas = gen_ft_dfa(args.k)

    opt = DFAOpt()
    initials, states, transitions = opt.optimize(dfas)

    ns_before = sum(d.n_states for d in dfas)
    ts_before = sum(d.n_transitions for d in dfas)
    ns_after = len(states)
    # States without outgoing transitions may be absent from `transitions`.
    ts_after = sum(len(transitions.get(s, {})) for s in states)

    # Comma-separated rows: label, metric, before, after, after/before ratio.
    print('Fat Tree (%d)' % args.k, ',', '#state', ',',
          ns_before, ',', ns_after, ',', ns_after / ns_before)
    print('Fat Tree (%d)' % args.k, ',', '#transition', ',',
          ts_before, ',', ts_after, ',', ts_after / ts_before)
